{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "TripletMarginLossMNIST.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU",
    "widgets": {
      "application/vnd.jupyter.widget-state+json": {
        "209f931b8f7149ddb74aa15e7bbbffaf": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_d065b04db4094618b2f269c652f3dafb",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_35b08cbab3804f8cb152250e34d7b223",
              "IPY_MODEL_b37600fd1f094428b4deb7631ac1e9b3"
            ]
          }
        },
        "d065b04db4094618b2f269c652f3dafb": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "35b08cbab3804f8cb152250e34d7b223": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_0a57f8668c054b0295091a15430fb80f",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_67c6c025f33249a89af8914de267153f"
          }
        },
        "b37600fd1f094428b4deb7631ac1e9b3": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_0e29bb9866454c1380f16cc8ad4dc1cd",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 9920512/? [00:20&lt;00:00, 1479281.38it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_29e560d467c84b218efde2814a3b40b3"
          }
        },
        "0a57f8668c054b0295091a15430fb80f": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "67c6c025f33249a89af8914de267153f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "0e29bb9866454c1380f16cc8ad4dc1cd": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "29e560d467c84b218efde2814a3b40b3": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9282bf5452e24cde87670c1988feeea9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_b82457ac62254705abbf3eaace763ad1",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_83a94043b1dd4f879714d6a8779009b7",
              "IPY_MODEL_f89f93f87d274dd5b321a66c4d81bb3e"
            ]
          }
        },
        "b82457ac62254705abbf3eaace763ad1": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "83a94043b1dd4f879714d6a8779009b7": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_f4607e80d7cd409ca1a3fab452c86873",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "success",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_d71788a45fcb4a4f91d344f48670d71d"
          }
        },
        "f89f93f87d274dd5b321a66c4d81bb3e": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_d51f3cc8ac504af983f9d211e2439f16",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 32768/? [00:00&lt;00:00, 117375.99it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_e35be65fe9534f3081628cb04212520a"
          }
        },
        "f4607e80d7cd409ca1a3fab452c86873": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "d71788a45fcb4a4f91d344f48670d71d": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "d51f3cc8ac504af983f9d211e2439f16": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "e35be65fe9534f3081628cb04212520a": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "69e271528688464abaf0821454792932": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_fc359856bc2842d798509f18dff00175",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_3566784d3a254e508d31a1b3260c4c4b",
              "IPY_MODEL_4bbc79fee05c4977b4cd0e0ae8fae795"
            ]
          }
        },
        "fc359856bc2842d798509f18dff00175": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "3566784d3a254e508d31a1b3260c4c4b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_08a7ad1fa1c4475faa952a39f522dbb9",
            "_dom_classes": [],
            "description": "",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 1,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_383e93b843d743558129a1c582e7bd2c"
          }
        },
        "4bbc79fee05c4977b4cd0e0ae8fae795": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_207d073d941043838b67172eb10a339b",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 1654784/? [00:18&lt;00:00, 579946.35it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_93ee625991f8438ab84e71475b066c5f"
          }
        },
        "08a7ad1fa1c4475faa952a39f522dbb9": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "383e93b843d743558129a1c582e7bd2c": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "207d073d941043838b67172eb10a339b": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "93ee625991f8438ab84e71475b066c5f": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "cb729196bd49483d91875bdd62e57380": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HBoxModel",
          "state": {
            "_view_name": "HBoxView",
            "_dom_classes": [],
            "_model_name": "HBoxModel",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "box_style": "",
            "layout": "IPY_MODEL_3a8deddbd7774d0bb55d868c7ac48fde",
            "_model_module": "@jupyter-widgets/controls",
            "children": [
              "IPY_MODEL_b086304f42784f7b8f67b8d7b01b294d",
              "IPY_MODEL_b3d8c4815c9446d08669d82fa55bd352"
            ]
          }
        },
        "3a8deddbd7774d0bb55d868c7ac48fde": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "b086304f42784f7b8f67b8d7b01b294d": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "FloatProgressModel",
          "state": {
            "_view_name": "ProgressView",
            "style": "IPY_MODEL_98fb70e8faac4ff69570a67aa8a22eb8",
            "_dom_classes": [],
            "description": "  0%",
            "_model_name": "FloatProgressModel",
            "bar_style": "info",
            "max": 1,
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": 0,
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "orientation": "horizontal",
            "min": 0,
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_3b2d933132754d77aaf0e224537fdbba"
          }
        },
        "b3d8c4815c9446d08669d82fa55bd352": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "HTMLModel",
          "state": {
            "_view_name": "HTMLView",
            "style": "IPY_MODEL_9d2a429215084021bfa96e58c36c7946",
            "_dom_classes": [],
            "description": "",
            "_model_name": "HTMLModel",
            "placeholder": "​",
            "_view_module": "@jupyter-widgets/controls",
            "_model_module_version": "1.5.0",
            "value": " 0/4542 [00:00&lt;?, ?it/s]",
            "_view_count": null,
            "_view_module_version": "1.5.0",
            "description_tooltip": null,
            "_model_module": "@jupyter-widgets/controls",
            "layout": "IPY_MODEL_193e0f7046084498abdd27c2b6304597"
          }
        },
        "98fb70e8faac4ff69570a67aa8a22eb8": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "ProgressStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "ProgressStyleModel",
            "description_width": "initial",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "bar_color": null,
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "3b2d933132754d77aaf0e224537fdbba": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        },
        "9d2a429215084021bfa96e58c36c7946": {
          "model_module": "@jupyter-widgets/controls",
          "model_name": "DescriptionStyleModel",
          "state": {
            "_view_name": "StyleView",
            "_model_name": "DescriptionStyleModel",
            "description_width": "",
            "_view_module": "@jupyter-widgets/base",
            "_model_module_version": "1.5.0",
            "_view_count": null,
            "_view_module_version": "1.2.0",
            "_model_module": "@jupyter-widgets/controls"
          }
        },
        "193e0f7046084498abdd27c2b6304597": {
          "model_module": "@jupyter-widgets/base",
          "model_name": "LayoutModel",
          "state": {
            "_view_name": "LayoutView",
            "grid_template_rows": null,
            "right": null,
            "justify_content": null,
            "_view_module": "@jupyter-widgets/base",
            "overflow": null,
            "_model_module_version": "1.2.0",
            "_view_count": null,
            "flex_flow": null,
            "width": null,
            "min_width": null,
            "border": null,
            "align_items": null,
            "bottom": null,
            "_model_module": "@jupyter-widgets/base",
            "top": null,
            "grid_column": null,
            "overflow_y": null,
            "overflow_x": null,
            "grid_auto_flow": null,
            "grid_area": null,
            "grid_template_columns": null,
            "flex": null,
            "_model_name": "LayoutModel",
            "justify_items": null,
            "grid_row": null,
            "max_height": null,
            "align_content": null,
            "visibility": null,
            "align_self": null,
            "height": null,
            "min_height": null,
            "padding": null,
            "grid_auto_rows": null,
            "grid_gap": null,
            "max_width": null,
            "order": null,
            "_view_module_version": "1.2.0",
            "grid_template_areas": null,
            "object_position": null,
            "object_fit": null,
            "grid_auto_columns": null,
            "margin": null,
            "display": null,
            "left": null
          }
        }
      }
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "lMde13VLDjiH",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "5522029e-b3e0-4628-b4d4-4c1e4cd73e1a"
      },
      "source": [
        "# Pinned to the pytorch-metric-learning release recorded in this notebook's saved\n",
        "# output; the code below passes positional arguments that follow that release's API.\n",
        "!pip install pytorch-metric-learning==0.9.94\n",
        "!pip install faiss-gpu"
      ],
      "execution_count": 1,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Collecting pytorch-metric-learning\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/90/3b/c44b7dc2743270b0e5536d685d3ec6d0437e1af22bee940c3560dd21207b/pytorch_metric_learning-0.9.94-py3-none-any.whl (96kB)\n",
            "\r\u001b[K     |███▍                            | 10kB 21.3MB/s eta 0:00:01\r\u001b[K     |██████▊                         | 20kB 15.5MB/s eta 0:00:01\r\u001b[K     |██████████▏                     | 30kB 13.3MB/s eta 0:00:01\r\u001b[K     |█████████████▌                  | 40kB 12.3MB/s eta 0:00:01\r\u001b[K     |█████████████████               | 51kB 9.2MB/s eta 0:00:01\r\u001b[K     |████████████████████▎           | 61kB 8.3MB/s eta 0:00:01\r\u001b[K     |███████████████████████▊        | 71kB 9.3MB/s eta 0:00:01\r\u001b[K     |███████████████████████████     | 81kB 10.3MB/s eta 0:00:01\r\u001b[K     |██████████████████████████████▌ | 92kB 9.4MB/s eta 0:00:01\r\u001b[K     |████████████████████████████████| 102kB 6.1MB/s \n",
            "\u001b[?25hRequirement already satisfied: torchvision in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (0.8.1+cu101)\n",
            "Requirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (1.18.5)\n",
            "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (0.22.2.post1)\n",
            "Requirement already satisfied: torch>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (1.7.0+cu101)\n",
            "Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from pytorch-metric-learning) (4.41.1)\n",
            "Requirement already satisfied: pillow>=4.1.1 in /usr/local/lib/python3.6/dist-packages (from torchvision->pytorch-metric-learning) (7.0.0)\n",
            "Requirement already satisfied: joblib>=0.11 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->pytorch-metric-learning) (0.17.0)\n",
            "Requirement already satisfied: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-learn->pytorch-metric-learning) (1.4.1)\n",
            "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.6/dist-packages (from torch>=1.6.0->pytorch-metric-learning) (3.7.4.3)\n",
            "Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from torch>=1.6.0->pytorch-metric-learning) (0.16.0)\n",
            "Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from torch>=1.6.0->pytorch-metric-learning) (0.8)\n",
            "Installing collected packages: pytorch-metric-learning\n",
            "Successfully installed pytorch-metric-learning-0.9.94\n",
            "Collecting faiss-gpu\n",
            "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/7d/32/8b29e3f99224f24716257e78724a02674761e034e6920b4278cc21d19f77/faiss_gpu-1.6.5-cp36-cp36m-manylinux2014_x86_64.whl (67.6MB)\n",
            "\u001b[K     |████████████████████████████████| 67.7MB 43kB/s \n",
            "\u001b[?25hInstalling collected packages: faiss-gpu\n",
            "Successfully installed faiss-gpu-1.6.5\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "GJ_L0TrTDnEA",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 668,
          "referenced_widgets": [
            "209f931b8f7149ddb74aa15e7bbbffaf",
            "d065b04db4094618b2f269c652f3dafb",
            "35b08cbab3804f8cb152250e34d7b223",
            "b37600fd1f094428b4deb7631ac1e9b3",
            "0a57f8668c054b0295091a15430fb80f",
            "67c6c025f33249a89af8914de267153f",
            "0e29bb9866454c1380f16cc8ad4dc1cd",
            "29e560d467c84b218efde2814a3b40b3",
            "9282bf5452e24cde87670c1988feeea9",
            "b82457ac62254705abbf3eaace763ad1",
            "83a94043b1dd4f879714d6a8779009b7",
            "f89f93f87d274dd5b321a66c4d81bb3e",
            "f4607e80d7cd409ca1a3fab452c86873",
            "d71788a45fcb4a4f91d344f48670d71d",
            "d51f3cc8ac504af983f9d211e2439f16",
            "e35be65fe9534f3081628cb04212520a",
            "69e271528688464abaf0821454792932",
            "fc359856bc2842d798509f18dff00175",
            "3566784d3a254e508d31a1b3260c4c4b",
            "4bbc79fee05c4977b4cd0e0ae8fae795",
            "08a7ad1fa1c4475faa952a39f522dbb9",
            "383e93b843d743558129a1c582e7bd2c",
            "207d073d941043838b67172eb10a339b",
            "93ee625991f8438ab84e71475b066c5f",
            "cb729196bd49483d91875bdd62e57380",
            "3a8deddbd7774d0bb55d868c7ac48fde",
            "b086304f42784f7b8f67b8d7b01b294d",
            "b3d8c4815c9446d08669d82fa55bd352",
            "98fb70e8faac4ff69570a67aa8a22eb8",
            "3b2d933132754d77aaf0e224537fdbba",
            "9d2a429215084021bfa96e58c36c7946",
            "193e0f7046084498abdd27c2b6304597"
          ]
        },
        "outputId": "c59d154f-16a0-41e2-e516-c9e8120a78ee"
      },
      "source": [
        "from pytorch_metric_learning import losses, miners, distances, reducers, testers\n",
        "from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator\n",
        "### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### \n",
        "from torchvision import datasets\n",
        "import torch\n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "import torch.optim as optim\n",
        "from torchvision import datasets, transforms\n",
        "import numpy as np\n",
        "\n",
        "### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### \n",
        "class Net(nn.Module):\n",
        "    def __init__(self):\n",
        "        super(Net, self).__init__()\n",
        "        self.conv1 = nn.Conv2d(1, 32, 3, 1)\n",
        "        self.conv2 = nn.Conv2d(32, 64, 3, 1)\n",
        "        self.dropout1 = nn.Dropout2d(0.25)\n",
        "        self.dropout2 = nn.Dropout2d(0.5)\n",
        "        self.fc1 = nn.Linear(9216, 128)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.conv1(x)\n",
        "        x = F.relu(x)\n",
        "        x = self.conv2(x)\n",
        "        x = F.relu(x)\n",
        "        x = F.max_pool2d(x, 2)\n",
        "        x = self.dropout1(x)\n",
        "        x = torch.flatten(x, 1)\n",
        "        x = self.fc1(x)\n",
        "        return x\n",
        "\n",
        "### MNIST code originally from https://github.com/pytorch/examples/blob/master/mnist/main.py ### \n",
        "def train(model, loss_func, mining_func, device, train_loader, optimizer, epoch):\n",
        "    model.train()\n",
        "    for batch_idx, (data, labels) in enumerate(train_loader):\n",
        "        data, labels = data.to(device), labels.to(device)\n",
        "        optimizer.zero_grad()\n",
        "        embeddings = model(data)\n",
        "        indices_tuple = mining_func(embeddings, labels)\n",
        "        loss = loss_func(embeddings, labels, indices_tuple)\n",
        "        loss.backward()\n",
        "        optimizer.step()\n",
        "        if batch_idx % 20 == 0:\n",
        "            print(\"Epoch {} Iteration {}: Loss = {}, Number of mined triplets = {}\".format(epoch, batch_idx, loss, mining_func.num_triplets))\n",
        "\n",
        "### convenient function from pytorch-metric-learning ###\n",
        "def get_all_embeddings(dataset, model):\n",
        "    \"\"\"Embed ``dataset`` with ``model`` through a freshly built ``testers.BaseTester``.\n",
        "\n",
        "    Returns whatever ``BaseTester.get_all_embeddings`` returns; ``test`` unpacks\n",
        "    that result as an ``(embeddings, labels)`` pair.\n",
        "    \"\"\"\n",
        "    return testers.BaseTester().get_all_embeddings(dataset, model)\n",
        "\n",
        "### compute accuracy using AccuracyCalculator from pytorch-metric-learning ###\n",
        "def test(train_set, test_set, model, accuracy_calculator):\n",
        "    \"\"\"Print the precision@1 of test-set embeddings looked up against train-set embeddings.\n",
        "\n",
        "    Both datasets are embedded with ``get_all_embeddings``; the result is only\n",
        "    printed, nothing is returned.\n",
        "    \"\"\"\n",
        "    reference_embeddings, reference_labels = get_all_embeddings(train_set, model)\n",
        "    query_embeddings, query_labels = get_all_embeddings(test_set, model)\n",
        "    print(\"Computing accuracy\")\n",
        "    # Arguments are positional: test embeddings, train embeddings, test labels,\n",
        "    # train labels, then a boolean flag. NOTE(review): the flag is presumably\n",
        "    # embeddings_come_from_same_source in the 0.9.x API - confirm before upgrading.\n",
        "    scores = accuracy_calculator.get_accuracy(\n",
        "        query_embeddings,\n",
        "        reference_embeddings,\n",
        "        np.squeeze(query_labels),\n",
        "        np.squeeze(reference_labels),\n",
        "        False,\n",
        "    )\n",
        "    print(\"Test set accuracy (Precision@1) = {}\".format(scores[\"precision_at_1\"]))\n",
        "\n",
        "# Use the GPU when one is visible; otherwise fall back to the CPU so that\n",
        "# moving the model and batches to `device` does not fail on a CPU-only runtime.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "transform = transforms.Compose([\n",
        "        transforms.ToTensor(),\n",
        "        transforms.Normalize((0.1307,), (0.3081,))\n",
        "    ])\n",
        "\n",
        "batch_size = 256\n",
        "\n",
        "# dataset1 is the training split (downloaded into the current directory if\n",
        "# missing); dataset2 is the held-out split read from the same directory.\n",
        "dataset1 = datasets.MNIST('.', train=True, download=True, transform=transform)\n",
        "dataset2 = datasets.MNIST('.', train=False, transform=transform)\n",
        "# Both loaders read batch_size so the value above is the single place to change it.\n",
        "train_loader = torch.utils.data.DataLoader(dataset1, batch_size=batch_size, shuffle=True)\n",
        "# test_loader is not consumed by the loop below: test() receives the datasets directly.\n",
        "test_loader = torch.utils.data.DataLoader(dataset2, batch_size=batch_size)\n",
        "\n",
        "model = Net().to(device)\n",
        "optimizer = optim.Adam(model.parameters(), lr=0.01)\n",
        "num_epochs = 1\n",
        "\n",
        "\n",
        "### pytorch-metric-learning stuff ###\n",
        "# The same distance object is shared by the loss and the miner, and both use margin 0.2.\n",
        "distance = distances.CosineSimilarity()\n",
        "reducer = reducers.ThresholdReducer(low = 0)\n",
        "loss_func = losses.TripletMarginLoss(margin = 0.2, distance = distance, reducer = reducer)\n",
        "mining_func = miners.TripletMarginMiner(margin = 0.2, distance = distance, type_of_triplets = \"semihard\")\n",
        "accuracy_calculator = AccuracyCalculator(include = (\"precision_at_1\",), k = 1)\n",
        "### pytorch-metric-learning stuff ###\n",
        "\n",
        "\n",
        "# Train for num_epochs epochs, reporting precision@1 after each one.\n",
        "for epoch in range(1, num_epochs+1):\n",
        "    train(model, loss_func, mining_func, device, train_loader, optimizer, epoch)\n",
        "    test(dataset1, dataset2, model, accuracy_calculator)\n",
        "\n"
      ],
      "execution_count": 2,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "209f931b8f7149ddb74aa15e7bbbffaf",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting ./MNIST/raw/train-images-idx3-ubyte.gz to ./MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "9282bf5452e24cde87670c1988feeea9",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz to ./MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz\n",
            "\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "69e271528688464abaf0821454792932",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting ./MNIST/raw/t10k-images-idx3-ubyte.gz to ./MNIST/raw\n",
            "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to ./MNIST/raw/t10k-labels-idx1-ubyte.gz\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "display_data",
          "data": {
            "application/vnd.jupyter.widget-view+json": {
              "model_id": "cb729196bd49483d91875bdd62e57380",
              "version_minor": 0,
              "version_major": 2
            },
            "text/plain": [
              "HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "stream",
          "text": [
            "Extracting ./MNIST/raw/t10k-labels-idx1-ubyte.gz to ./MNIST/raw\n",
            "Processing...\n",
            "Done!\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.6/dist-packages/torchvision/datasets/mnist.py:480: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /pytorch/torch/csrc/utils/tensor_numpy.cpp:141.)\n",
            "  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Epoch 1 Iteration 0: Loss = 0.1037939116358757, Number of mined triplets = 776687\n",
            "Epoch 1 Iteration 20: Loss = 0.09549278020858765, Number of mined triplets = 123048\n",
            "Epoch 1 Iteration 40: Loss = 0.08915364742279053, Number of mined triplets = 89530\n",
            "Epoch 1 Iteration 60: Loss = 0.08617275953292847, Number of mined triplets = 50026\n",
            "Epoch 1 Iteration 80: Loss = 0.08428513258695602, Number of mined triplets = 42704\n",
            "Epoch 1 Iteration 100: Loss = 0.08503272384405136, Number of mined triplets = 44915\n",
            "Epoch 1 Iteration 120: Loss = 0.08499237149953842, Number of mined triplets = 34030\n",
            "Epoch 1 Iteration 140: Loss = 0.08733508735895157, Number of mined triplets = 45407\n",
            "Epoch 1 Iteration 160: Loss = 0.08337089419364929, Number of mined triplets = 32317\n",
            "Epoch 1 Iteration 180: Loss = 0.08752584457397461, Number of mined triplets = 40344\n",
            "Epoch 1 Iteration 200: Loss = 0.08363424241542816, Number of mined triplets = 27893\n",
            "Epoch 1 Iteration 220: Loss = 0.08256012946367264, Number of mined triplets = 19185\n"
          ],
          "name": "stdout"
        },
        {
          "output_type": "stream",
          "text": [
            "100%|██████████| 1875/1875 [00:15<00:00, 119.05it/s]\n",
            "100%|██████████| 313/313 [00:03<00:00, 78.37it/s] \n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "Computing accuracy\n",
            "Test set accuracy (Precision@1) = 0.9789\n"
          ],
          "name": "stdout"
        }
      ]
    }
  ]
}